{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04_batchnorm_simple.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/mlelarge/dataflowr/blob/master/Notebooks/04_deeper/04_batchnorm_simple_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "gaR5P442yOfy",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# BatchNorm intuitions"
      ]
    },
    {
      "metadata": {
        "id": "NToGasMiyOf1",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%matplotlib inline\n",
        "# %matplotlib notebook\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "from torch import nn\n",
        "import torch.utils.data as Data\n",
        "import torch.nn.functional as F\n",
        "\n",
        "SEED = 42  # one value feeds every RNG so a fresh Restart & Run All is repeatable\n",
        "torch.manual_seed(SEED)   # PyTorch global generator\n",
        "np.random.seed(SEED)      # NumPy global generator (used by np.random.normal below)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "gDNfIkj9yOf5",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## 1. Impact on activations and results"
      ]
    },
    {
      "metadata": {
        "id": "vep1e8KfyOf5",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class ExperimentParams():\n",
        "    \"\"\"Hyper-parameters and runtime settings shared by the experiments below.\n",
        "\n",
        "    Every field has a default; pass keyword arguments to override any of\n",
        "    them, e.g. ``ExperimentParams(lr=1e-2, n_hidden=4)``.\n",
        "\n",
        "    Raises:\n",
        "        AttributeError: if an override names a field that does not exist\n",
        "            (guards against silently ignored typos).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, **overrides):\n",
        "        # data and optimisation\n",
        "        self.num_samples = 2000    # size of each generated training set\n",
        "        self.batch_size = 64\n",
        "        self.lr = 3e-2             # learning rate handed to Adam\n",
        "        self.n_hidden = 8          # number of hidden layers in Net\n",
        "        self.num_epochs = 12\n",
        "        self.num_workers = 4\n",
        "\n",
        "        # torch.tanh computes the same function as F.tanh, which PyTorch\n",
        "        # releases have flagged as deprecated in favour of torch.tanh\n",
        "        self.activation = torch.tanh\n",
        "        self.bias_init = -0.2\n",
        "\n",
        "        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "        self.data_dir = '/home/docker_user/'\n",
        "\n",
        "        # apply caller overrides last so they win over the defaults above\n",
        "        for name, value in overrides.items():\n",
        "            if not hasattr(self, name):\n",
        "                raise AttributeError(f'unknown experiment parameter: {name!r}')\n",
        "            setattr(self, name, value)\n",
        "\n",
        "\n",
        "args = ExperimentParams()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "W-uaAatJyOf7",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def _noisy_parabola(n_points):\n",
        "    \"\"\"Return (x, y) column arrays with y = x**2 - 5 + Gaussian noise (std 2).\n",
        "\n",
        "    x holds n_points evenly spaced values in [-7, 10]; the noise comes from\n",
        "    NumPy's global RNG, so the order of calls determines the samples.\n",
        "    \"\"\"\n",
        "    xs = np.linspace(-7, 10, n_points).reshape(-1, 1)\n",
        "    ys = np.square(xs) - 5 + np.random.normal(0, 2, xs.shape)\n",
        "    return xs, ys\n",
        "\n",
        "\n",
        "# draw the training split first, then the test split (keeps the RNG order)\n",
        "x, y = _noisy_parabola(args.num_samples)\n",
        "test_x, test_y = _noisy_parabola(200)\n",
        "\n",
        "# float32 tensors for PyTorch\n",
        "train_x, train_y = (torch.from_numpy(arr).float() for arr in (x, y))\n",
        "test_x, test_y = (torch.from_numpy(arr).float() for arr in (test_x, test_y))\n",
        "\n",
        "train_dataset = Data.TensorDataset(train_x, train_y)\n",
        "train_loader = Data.DataLoader(\n",
        "    dataset=train_dataset,\n",
        "    batch_size=args.batch_size,\n",
        "    shuffle=True,\n",
        "    num_workers=2,\n",
        ")\n",
        "\n",
        "# show data\n",
        "plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')\n",
        "plt.legend(loc='upper left')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "hqpXdVuZyOf-",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "class Net(nn.Module):\n",
        "    \"\"\"Fully connected tanh regressor (1 input, 1 output) with optional BatchNorm.\n",
        "\n",
        "    When ``batch_normalization`` is True the input and every hidden\n",
        "    pre-activation go through ``nn.BatchNorm1d(momentum=0.5)``.\n",
        "\n",
        "    Args:\n",
        "        n_hidden: number of hidden ``nn.Linear`` layers.\n",
        "        batch_normalization: add BatchNorm on the input and after each\n",
        "            hidden linear layer.\n",
        "        bias_init: constant every linear bias is initialised to.\n",
        "        width: number of units in each hidden layer.\n",
        "\n",
        "    ``forward`` returns ``(out, layer_input, pre_activation)``; both lists\n",
        "    hold ``n_hidden + 1`` tensors (entry 0 describes the network input).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, n_hidden, batch_normalization=False, bias_init=-0.2, width=10):\n",
        "        super(Net, self).__init__()\n",
        "        self.n_hidden = n_hidden\n",
        "        self.do_bn = batch_normalization\n",
        "        self.width = width\n",
        "        # Plain lists give forward() index access; each layer is ALSO attached\n",
        "        # with setattr below, which is what registers it on the Module.\n",
        "        self.fcs = []\n",
        "        self.bns = []\n",
        "        # Build the input BatchNorm only when it is used, so the plain network\n",
        "        # does not carry (and print) an idle BatchNorm layer.\n",
        "        self.bn_input = nn.BatchNorm1d(1, momentum=0.5) if self.do_bn else None\n",
        "        self.activation = nn.Tanh()\n",
        "        self.bias_init = bias_init\n",
        "\n",
        "        for i in range(self.n_hidden):          # build hidden layers and BN layers\n",
        "            input_size = 1 if i == 0 else width\n",
        "            fc = nn.Linear(input_size, width)\n",
        "            setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module\n",
        "            self._set_init(fc)                  # parameters initialization\n",
        "            self.fcs.append(fc)\n",
        "            if self.do_bn:\n",
        "                bn = nn.BatchNorm1d(width, momentum=0.5)\n",
        "                setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Module\n",
        "                self.bns.append(bn)\n",
        "\n",
        "        # with no hidden layer the output layer reads the 1-feature input directly\n",
        "        last_size = width if self.n_hidden > 0 else 1\n",
        "        self.predict = nn.Linear(last_size, 1)  # output layer\n",
        "        self._set_init(self.predict)            # parameters initialization\n",
        "\n",
        "    def _set_init(self, layer):\n",
        "        \"\"\"Initialise ``layer``: weights ~ N(0, 0.1), bias = ``self.bias_init``.\"\"\"\n",
        "        nn.init.normal_(layer.weight, mean=0., std=.1)\n",
        "        nn.init.constant_(layer.bias, self.bias_init)\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Return the prediction plus the per-layer tensors used for histograms.\"\"\"\n",
        "        pre_activation = [x]         # entry 0: the input before any BatchNorm\n",
        "        if self.do_bn:\n",
        "            x = self.bn_input(x)     # input batch normalization\n",
        "\n",
        "        layer_input = [x]            # entry 0: what the first hidden layer receives\n",
        "        for i in range(self.n_hidden):\n",
        "            x = self.fcs[i](x)\n",
        "            pre_activation.append(x)\n",
        "            if self.do_bn:\n",
        "                x = self.bns[i](x)   # batch normalization\n",
        "            x = self.activation(x)\n",
        "            layer_input.append(x)\n",
        "\n",
        "        out = self.predict(x)\n",
        "        return out, layer_input, pre_activation"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "NyKj1eggyOgA",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# One network per setting, plain first: index 0 is the network without BatchNorm.\n",
        "# bias_init is taken from args so the value configured there is the one actually used.\n",
        "nets = [\n",
        "    Net(n_hidden=args.n_hidden, batch_normalization=use_bn, bias_init=args.bias_init)\n",
        "    for use_bn in (False, True)\n",
        "]\n",
        "\n",
        "print(*nets)    # print net architecture"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "yCbHEUdPyOgD",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# mean squared error between prediction and target\n",
        "loss_fn = nn.MSELoss()\n",
        "\n",
        "# one independent Adam optimizer per network, all with the same learning rate\n",
        "optimizers = [torch.optim.Adam(model.parameters(), lr=args.lr) for model in nets]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "UuwLYxaxyOgG",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn, axs=None):\n",
        "    \"\"\"Draw per-layer histograms of pre-activations and activations.\n",
        "\n",
        "    The grid has four rows (plain pre-activation, BatchNorm pre-activation,\n",
        "    plain activation, BatchNorm activation) and one column per entry of the\n",
        "    input lists.\n",
        "\n",
        "    Args:\n",
        "        l_in, l_in_bn: per-layer activation tensors of the plain / BatchNorm net.\n",
        "        pre_ac, pre_ac_bn: per-layer pre-activation tensors of the two nets.\n",
        "        axs: optional 4 x n_layers grid of matplotlib Axes. When omitted, the\n",
        "            current figure is used if it already holds such a grid (e.g. one\n",
        "            just made with ``plt.subplots(4, n_layers)``); otherwise a new\n",
        "            figure is created. The function no longer needs a module-level\n",
        "            ``axs`` variable to exist.\n",
        "    \"\"\"\n",
        "    n_layers = len(pre_ac)\n",
        "    if axs is None:\n",
        "        # get_fignums() guard: gcf() alone would create a stray empty figure\n",
        "        existing = plt.gcf().axes if plt.get_fignums() else []\n",
        "        if len(existing) == 4 * n_layers:\n",
        "            axs = existing\n",
        "        else:\n",
        "            _, axs = plt.subplots(4, n_layers, figsize=(10, 5), squeeze=False)\n",
        "    axs = np.asarray(axs).reshape(4, -1)\n",
        "\n",
        "    def as_flat(tensor):\n",
        "        # detach() drops autograd tracking so the values can become a NumPy array\n",
        "        return tensor.detach().cpu().numpy().ravel()\n",
        "\n",
        "    for i in range(axs.shape[1]):\n",
        "        ax_pa, ax_pa_bn, ax, ax_bn = axs[:, i]\n",
        "        column = (ax_pa, ax_pa_bn, ax, ax_bn)\n",
        "        for a in column:\n",
        "            a.clear()\n",
        "\n",
        "        # column 0 uses a wider range than the later columns\n",
        "        if i == 0:\n",
        "            p_range = the_range = (-7, 10)\n",
        "        else:\n",
        "            p_range, the_range = (-4, 4), (-1, 1)\n",
        "\n",
        "        ax_pa.set_title('L' + str(i))\n",
        "        ax_pa.hist(as_flat(pre_ac[i]), bins=10, range=p_range, color='orange', alpha=0.5)\n",
        "        ax_pa_bn.hist(as_flat(pre_ac_bn[i]), bins=10, range=p_range, color='green', alpha=0.5)\n",
        "        ax.hist(as_flat(l_in[i]), bins=10, range=the_range, color='orange')\n",
        "        ax_bn.hist(as_flat(l_in_bn[i]), bins=10, range=the_range, color='green')\n",
        "\n",
        "        # hide all ticks, then mark only the range limits on the two BatchNorm rows\n",
        "        for a in column:\n",
        "            a.set_yticks(())\n",
        "            a.set_xticks(())\n",
        "        ax_pa_bn.set_xticks(p_range)\n",
        "        ax_bn.set_xticks(the_range)\n",
        "\n",
        "    # row labels live on the first column; set them once, after the clear() calls\n",
        "    for row, label in enumerate(('PreAct', 'BN PreAct', 'Act', 'BN Act')):\n",
        "        axs[row, 0].set_ylabel(label)\n",
        "    plt.pause(0.01)  # gives the backend a chance to render the figure now"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "xbC1DNmpyOgJ",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# training\n",
        "losses = [[], []]  # per-epoch test loss, one list per network\n",
        "for epoch in range(args.num_epochs):\n",
        "    print('Epoch: ', epoch)\n",
        "    layer_inputs, pre_acts = [], []\n",
        "    for net, l in zip(nets, losses):\n",
        "        net.eval()              # set eval mode to fix moving_mean and moving_var\n",
        "        with torch.no_grad():   # evaluation only: do not build an autograd graph\n",
        "            pred, layer_input, pre_act = net(test_x)\n",
        "            l.append(loss_fn(pred, test_y).item())\n",
        "        layer_inputs.append(layer_input)\n",
        "        pre_acts.append(pre_act)\n",
        "        net.train()             # free moving_mean and moving_var\n",
        "\n",
        "    # fresh 4 x (n_hidden+1) grid for this epoch; plot_histogram draws on it\n",
        "    f, axs = plt.subplots(4, args.n_hidden+1, figsize=(10, 5))\n",
        "    plot_histogram(*layer_inputs, *pre_acts)     # plot histogram\n",
        "\n",
        "    for b_x, b_y in train_loader:\n",
        "        for net, opt in zip(nets, optimizers):     # train each network on the same batch\n",
        "            pred, _, _ = net(b_x)\n",
        "            loss = loss_fn(pred, b_y)\n",
        "            opt.zero_grad()\n",
        "            loss.backward()\n",
        "            opt.step()    # also updates the learnable BatchNorm parameters"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "5kQXQMVKyOgL",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# plot the test loss recorded once per epoch during training\n",
        "plt.figure(2)\n",
        "plt.plot(losses[0], c='orange', lw=3, label='Original')\n",
        "plt.plot(losses[1], c='green', lw=3, label='Batch Normalization')\n",
        "plt.xlabel('epoch')        # one value is appended per epoch, not per step\n",
        "plt.ylabel('test loss')\n",
        "plt.ylim((0, 2000))\n",
        "plt.legend(loc='best')\n",
        "\n",
        "# evaluation\n",
        "# eval mode fixes moving_mean and moving_var in the batch normalization layers\n",
        "for net in nets:\n",
        "    net.eval()\n",
        "with torch.no_grad():      # inference only: no autograd graph needed\n",
        "    preds = [net(test_x)[0] for net in nets]\n",
        "\n",
        "plt.figure(3)\n",
        "plt.plot(test_x.numpy(), preds[0].numpy(), c='orange', lw=4, label='Original')\n",
        "plt.plot(test_x.numpy(), preds[1].numpy(), c='green', lw=4, label='Batch Normalization')\n",
        "# the scattered points are the test set, so label them as such\n",
        "plt.scatter(test_x.numpy(), test_y.numpy(), c='r', s=50, alpha=0.2, label='test')\n",
        "plt.legend(loc='best')\n",
        "plt.show()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "M9ClPalzyOgN",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## 2. Dependency on weight initialization"
      ]
    },
    {
      "metadata": {
        "id": "UEUgVM9hyOgO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "We will generate a toy dataset using `scikit-learn`"
      ]
    },
    {
      "metadata": {
        "id": "QOibg6nayOgP",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from sklearn import datasets"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "QAg24Il8yOgR",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# training split (NumPy arrays x, y are reused by the scatter plot below)\n",
        "x, y = datasets.make_circles(n_samples=args.num_samples, factor=.5, noise=.05)\n",
        "train_x = torch.from_numpy(x).float()\n",
        "train_y = torch.from_numpy(y).long()\n",
        "\n",
        "# test split, drawn after the training split\n",
        "test_x_np, test_y_np = datasets.make_circles(n_samples=200, factor=.5, noise=.05)\n",
        "test_x = torch.from_numpy(test_x_np).float()\n",
        "test_y = torch.from_numpy(test_y_np).long()\n",
        "\n",
        "train_dataset = Data.TensorDataset(train_x, train_y)\n",
        "train_loader = Data.DataLoader(\n",
        "    dataset=train_dataset,\n",
        "    batch_size=args.batch_size,\n",
        "    shuffle=True,\n",
        "    num_workers=2,\n",
        ")\n",
        "\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "9p5k31SjyOgT",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "\n",
        "plt.figure(figsize=(10,10))\n",
        "# one scatter call per class, in label order, so the legend entries line up\n",
        "for label, color in enumerate(['orange', 'green']):\n",
        "    members = x[y == label]\n",
        "    plt.scatter(members[:, 0], members[:, 1], alpha=0.5, color=color)\n",
        "\n",
        "plt.legend(['class 0', 'class 1'])\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "QFY32UjqyOgW",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "The function below builds a fully connected model on the fly, with or without BatchNorm layers."
      ]
    },
    {
      "metadata": {
        "id": "Vv_kNKaEyOgW",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def create_model(with_batchnorm, nc = 32, depth = 16):\n",
        "    \"\"\"Build an MLP taking 2 input features and producing 2 outputs.\n",
        "\n",
        "    Layout: Linear(2, nc) -> [BatchNorm1d] -> ReLU, then ``depth`` repeats of\n",
        "    Linear(nc, nc) -> [BatchNorm1d] -> ReLU, then Linear(nc, 2). The\n",
        "    BatchNorm1d layers are present only when ``with_batchnorm`` is true.\n",
        "    \"\"\"\n",
        "    def stage(in_features):\n",
        "        # one Linear -> (optional BatchNorm) -> ReLU group, created in that order\n",
        "        group = [nn.Linear(in_features, nc)]\n",
        "        if with_batchnorm:\n",
        "            group.append(nn.BatchNorm1d(nc))\n",
        "        group.append(nn.ReLU())\n",
        "        return group\n",
        "\n",
        "    layers = stage(2)\n",
        "    for _ in range(depth):\n",
        "        layers.extend(stage(nc))\n",
        "    layers.append(nn.Linear(nc, 2))\n",
        "\n",
        "    return nn.Sequential(*layers)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "hWQa3xbVyOgY",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "stds = [1e-3, 1e-2, 1e-1, 1e0, 1e1]\n",
        "args.num_epochs = 10\n",
        "\n",
        "loss_fn = torch.nn.CrossEntropyLoss()   # stateless, so one instance serves the whole sweep\n",
        "\n",
        "accuracies = [[], []]\n",
        "losses = [[], []]  # record test loss for two networks\n",
        "for std in stds:\n",
        "    print(f'Initializing nets with std: {std}')\n",
        "    nets = [create_model(with_batchnorm=False), create_model(with_batchnorm=True)]\n",
        "\n",
        "    optimizers = [torch.optim.Adam(net.parameters(), lr=args.lr) for net in nets]\n",
        "\n",
        "    for net in nets:\n",
        "        with torch.no_grad():\n",
        "            # overwrite EVERY parameter returned by net.parameters() with N(0, std) samples\n",
        "            for p in net.parameters():\n",
        "                p.normal_(0, std)\n",
        "        net.train()   # training mode for each network, not only the last one in the list\n",
        "\n",
        "    for epoch in range(args.num_epochs):\n",
        "        for b_x, b_y in train_loader:\n",
        "            for net, opt in zip(nets, optimizers):     # train each network on the same batch\n",
        "                output = net(b_x)\n",
        "                loss = loss_fn(output, b_y)\n",
        "                opt.zero_grad()\n",
        "                loss.backward()\n",
        "                opt.step()    # also updates the learnable BatchNorm parameters\n",
        "\n",
        "    for net, l, acc in zip(nets, losses, accuracies):\n",
        "        net.eval()\n",
        "        with torch.no_grad():   # evaluation only: do not build an autograd graph\n",
        "            output = net(test_x)\n",
        "            l.append(loss_fn(output, test_y).item())\n",
        "            pred = output.argmax(dim=1)                         # predicted class per sample\n",
        "            acc.append((pred == test_y).float().mean().item())  # fraction classified correctly\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "2vgQSRMwyOgb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# plotting\n",
        "def plot_metric_vs_std(std_values, values_plain, values_bn, ylabel, title):\n",
        "    \"\"\"Plot one test metric of both networks against the initialisation std.\n",
        "\n",
        "    Args:\n",
        "        std_values: the stds that were swept (x-axis, drawn on a log scale).\n",
        "        values_plain: metric of the network without BatchNorm, one per std.\n",
        "        values_bn: metric of the network with BatchNorm, one per std.\n",
        "        ylabel: y-axis label.\n",
        "        title: figure title.\n",
        "    \"\"\"\n",
        "    fig, ax = plt.subplots()\n",
        "    ax.plot(std_values, values_plain, color='orange', lw=3, label='no batchnorm')\n",
        "    ax.plot(std_values, values_bn, color='green', lw=3, label='with batchnorm')\n",
        "    # the stds span several decades; on a linear axis the small ones pile up near 0\n",
        "    ax.set_xscale('log')\n",
        "    ax.set_xlabel('weight std')\n",
        "    ax.set_ylabel(ylabel)\n",
        "    ax.set_title(title)\n",
        "    ax.legend(loc='upper left')\n",
        "    ax.grid(True)\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "plot_metric_vs_std(stds, losses[0], losses[1], 'test loss', 'Test loss')\n",
        "plot_metric_vs_std(stds, accuracies[0], accuracies[1], 'test accuracy', 'Test accuracy')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "EMYtBsayyOgd",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## 3. Exercises\n",
        "\n",
        "1. Code a pipeline similar to the Dropout one, with conv layers, and compare the variants with and without BatchNorm. What do you notice?\n",
        "2. Try out larger learning rate values. Plot the decrease in training and test error.\n",
        "3. Do we need Batch Normalization in every layer? Experiment with it."
      ]
    }
  ]
}